{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Tree Mapping with SAMGeo and Segment Anything Model 2 (SAM 2)\n",
    "\n",
    "[![image](https://studiolab.sagemaker.aws/studiolab.svg)](https://studiolab.sagemaker.aws/import/github/opengeos/segment-geospatial/blob/main/docs/examples/tree_mapping.ipynb)\n",
    "[![image](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/opengeos/segment-geospatial/blob/main/docs/examples/tree_mapping.ipynb)\n",
    "\n",
    "This notebook shows how to segment trees from aerial imagery with the Segment Anything Model 2 (SAM 2).\n",
    "\n",
    "Make sure you use a GPU runtime for this notebook. For Google Colab, go to `Runtime` -> `Change runtime type` and select `GPU` as the hardware accelerator."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Install dependencies\n",
    "\n",
    "Uncomment and run the following cell to install the required dependencies."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# %pip install segment-geospatial"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import leafmap\n",
    "from samgeo import SamGeo2"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Create an interactive map"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "m = leafmap.Map(center=[-22.17615, -51.253043], zoom=18, height=\"800px\")\n",
    "m.add_basemap(\"SATELLITE\")\n",
    "m"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Download a sample image\n",
    "\n",
    "Pan and zoom the map to select the area of interest. Use the drawing tools to draw a polygon or rectangle on the map. If nothing is drawn, the default bounding box in the following cell is used."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "bbox = m.user_roi_bounds()\n",
    "if bbox is None:\n",
    "    bbox = [-51.2565, -22.1777, -51.2512, -22.175]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "image = \"Image.tif\"\n",
    "leafmap.map_tiles_to_geotiff(\n",
    "    output=image, bbox=bbox, zoom=19, source=\"Satellite\", overwrite=True\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You can also use your own image. Uncomment the following cell, set the path to your GeoTIFF, and run it."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# image = '/path/to/your/own/image.tif'"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Display the downloaded image on the map."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "m.layers[-1].visible = False\n",
    "m.add_raster(image, layer_name=\"Image\")\n",
    "m"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Initialize SAM class\n",
    "\n",
    "Set `automatic=False` to enable the `SAM2ImagePredictor`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "sam = SamGeo2(\n",
    "    model_id=\"sam2-hiera-large\",\n",
    "    automatic=False,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Specify the image to segment. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "sam.set_image(image)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Display the map. Use the drawing tools to draw rectangles around the features you want to extract, such as trees or buildings."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "m"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Create bounding boxes\n",
    "\n",
    "If no rectangles are drawn, the following default bounding boxes are used:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if m.user_rois is not None:\n",
    "    boxes = m.user_rois\n",
    "else:\n",
    "    boxes = [\n",
    "        [-51.2546, -22.1771, -51.2541, -22.1767],\n",
    "        [-51.2538, -22.1764, -51.2535, -22.1761],\n",
    "    ]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Segment the image\n",
    "\n",
    "Use the `predict()` method to segment the image with the specified bounding boxes. The `boxes` parameter accepts a list of bounding box coordinates in the format `[[left, bottom, right, top], [left, bottom, right, top], ...]`, a GeoJSON dictionary, or a file path to a GeoJSON file. Because the boxes here are given in longitude and latitude, `point_crs` is set to `EPSG:4326`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "sam.predict(boxes=boxes, point_crs=\"EPSG:4326\", output=\"mask.tif\", dtype=\"uint8\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Display the result\n",
    "\n",
    "Add the segmented image to the map."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "m.add_raster(\"mask.tif\", colormap=\"viridis\", nodata=0, layer_name=\"Mask\")\n",
    "m"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Use an existing vector dataset as box prompts\n",
    "\n",
    "You can also use an existing vector dataset as box prompts. The following example uses an existing dataset of tree bounding boxes from GitHub."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "geojson = (\n",
    "    \"https://github.com/opengeos/datasets/releases/download/samgeo/tree_boxes.geojson\"\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Display the bounding boxes on the map."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "m = leafmap.Map()\n",
    "m.add_raster(image, layer_name=\"Image\")\n",
    "style = {\n",
    "    \"color\": \"#ffff00\",\n",
    "    \"weight\": 2,\n",
    "    \"fillColor\": \"#7c4185\",\n",
    "    \"fillOpacity\": 0,\n",
    "}\n",
    "m.add_vector(\n",
    "    geojson,\n",
    "    style=style,\n",
    "    zoom_to_layer=True,\n",
    "    layer_name=\"Bounding boxes\",\n",
    "    info_mode=None,\n",
    ")\n",
    "m"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Segment trees with box prompts\n",
    "\n",
    "Segment trees using the bounding boxes from the vector dataset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "output_masks = \"mask2.tif\"\n",
    "sam.predict(boxes=geojson, point_crs=\"EPSG:4326\", output=output_masks, dtype=\"uint8\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Display the segmented masks on the map."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "m.add_raster(output_masks, nodata=0, opacity=0.5, layer_name=\"Tree masks\")\n",
    "m"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Post-processing\n",
    "\n",
    "You can use the `region_groups()` method to clean up the segmentation results, for example by removing small regions and filling holes. It can also compute geometric properties of the regions, such as area, perimeter, eccentricity, and solidity."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "out_image = \"tree_masks.tif\"\n",
    "out_vector = \"tree_vector.geojson\"\n",
    "array, gdf = sam.region_groups(\n",
    "    output_masks, min_size=200, out_vector=out_vector, out_image=out_image\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "gdf.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Display the cleaned masks"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "m = leafmap.Map()\n",
    "m.add_raster(image, layer_name=\"Image\")\n",
    "style = {\n",
    "    \"color\": \"#ffff00\",\n",
    "    \"weight\": 2,\n",
    "    \"fillColor\": \"#7c4185\",\n",
    "    \"fillOpacity\": 0,\n",
    "}\n",
    "m.add_raster(\n",
    "    out_image, colormap=\"tab20\", nodata=0, opacity=0.7, layer_name=\"Tree masks\"\n",
    ")\n",
    "m.add_vector(out_vector, style=style, zoom_to_layer=True, layer_name=\"Tree vector\")\n",
    "m.add_vector(\n",
    "    geojson,\n",
    "    style={\"color\": \"blue\", \"fillOpacity\": 0},\n",
    "    layer_name=\"Bounding boxes\",\n",
    "    info_mode=None,\n",
    ")\n",
    "m.add_layer_manager()\n",
    "m"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "![image](https://github.com/user-attachments/assets/b789a0e6-6e76-4b10-a9b8-3fc14676481f)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Create a split map"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "m = leafmap.Map()\n",
    "m.add_raster(image, layer_name=\"Image\")\n",
    "m.split_map(\n",
    "    out_image,\n",
    "    image,\n",
    "    left_label=\"Tree masks\",\n",
    "    right_label=\"Aerial imagery\",\n",
    "    left_args={\"colormap\": \"tab20\", \"nodata\": 0, \"opacity\": 0.7},\n",
    ")\n",
    "m"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "![demo](https://github.com/user-attachments/assets/7bb0a65c-94f1-4cb6-9361-e79b47ec1e0a)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
